use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::common::{Result, ScalarValue, plan_err};
use datafusion::prelude::SessionContext;
use datafusion_expr::Expr;
use std::fmt::Debug;
use std::sync::Arc;

use crate::tpch::TpchTableKind;

/// Defines a TPCH table function (UDTF) provider type.
///
/// Arguments:
/// * `$TABLE_FUNCTION_NAME` - identifier of the unit-like struct to define.
/// * `$TABLE_FUNCTION_SQL_NAME` - identifier whose text is the SQL name of the
///   table function; it is expected to start with `tpch_`.
/// * `$NAME` - the `TpchTableKind` variant the provider generates.
///
/// The generated struct gets `name()` / `table_name()` associated functions
/// and a `TableFunctionImpl` implementation that validates the call arguments
/// and builds a `super::source::TpchSource`.
macro_rules! define_tpch_udtf_provider {
    ($TABLE_FUNCTION_NAME:ident, $TABLE_FUNCTION_SQL_NAME:ident, $NAME:ident) => {
        #[derive(Debug)]
        pub struct $TABLE_FUNCTION_NAME {}

        impl $TABLE_FUNCTION_NAME {
            /// Returns the name of the table function.
            pub fn name() -> &'static str {
                stringify!($TABLE_FUNCTION_SQL_NAME)
            }

            /// Returns the name of the table generated by the table function
            /// when used in a SQL query, i.e. the function name without its
            /// `tpch_` prefix.
            ///
            /// # Panics
            ///
            /// Panics if the SQL name passed to the macro does not start with
            /// `tpch_`; that is a mistake at the macro invocation site.
            pub fn table_name() -> &'static str {
                stringify!($TABLE_FUNCTION_SQL_NAME)
                    .strip_prefix("tpch_")
                    .unwrap_or_else(|| {
                        panic!(
                            "Table function name {} does not start with tpch_",
                            stringify!($TABLE_FUNCTION_SQL_NAME)
                        )
                    })
            }
        }

        impl TableFunctionImpl for $TABLE_FUNCTION_NAME {
            /// Implementation of the UDTF invocation for TPCH table generation.
            ///
            /// The first argument is a float literal that specifies the scale
            /// factor.
            /// The optional second argument is an i64 literal, greater than 0,
            /// that is passed to the source as `num_parts`; it defaults to 1
            /// when omitted.
            ///
            /// # Errors
            ///
            /// Returns a plan error when the first argument is missing or is
            /// not a non-null float literal, or when the second argument is
            /// present but is not a non-null i64 literal, is less than 1, or
            /// does not fit the integer type used for `num_parts`.
            fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
                let Some(Expr::Literal(ScalarValue::Float64(Some(scale_factor)))) =
                    args.first()
                else {
                    return plan_err!("First argument must be a float literal.");
                };

                let mut num_parts = 1;

                // An optional second argument overrides the default `num_parts`.
                if args.len() > 1 {
                    let Some(Expr::Literal(ScalarValue::Int64(Some(n)))) = args.get(1)
                    else {
                        return plan_err!("Second argument must be an i64 literal.");
                    };

                    // The value comes straight from the SQL text, so reject bad
                    // input with a plan error instead of panicking.
                    if *n < 1 {
                        return plan_err!(
                            "Second argument must be greater than 0, got {}.",
                            n
                        );
                    }

                    num_parts = match (*n).try_into() {
                        Ok(parts) => parts,
                        Err(_) => {
                            return plan_err!("Second argument {} is out of range.", n)
                        }
                    };
                }

                let provider = super::source::TpchSource {
                    scale_factor: *scale_factor,
                    num_parts,
                    kind: TpchTableKind::$NAME,
                };

                Ok(Arc::new(provider))
            }
        }
    };
}

// One UDTF provider per TPCH table. Each invocation passes the struct name to
// define, the SQL function name (which must carry the `tpch_` prefix that
// `table_name()` strips), and the `TpchTableKind` variant to generate.
define_tpch_udtf_provider!(TpchNation, tpch_nation, Nation);

define_tpch_udtf_provider!(TpchCustomer, tpch_customer, Customer);

define_tpch_udtf_provider!(TpchOrders, tpch_orders, Orders);

define_tpch_udtf_provider!(TpchLineitem, tpch_lineitem, LineItem);

define_tpch_udtf_provider!(TpchPart, tpch_part, Part);

define_tpch_udtf_provider!(TpchPartsupp, tpch_partsupp, PartSupp);

define_tpch_udtf_provider!(TpchSupplier, tpch_supplier, Supplier);

define_tpch_udtf_provider!(TpchRegion, tpch_region, Region);

/// Registers all the TPCH UDTFs in the given session context.
pub fn register_tpch_udtfs(ctx: &SessionContext) -> Result<()> {
    ctx.register_udtf(TpchNation::name(), Arc::new(TpchNation {}));
    ctx.register_udtf(TpchCustomer::name(), Arc::new(TpchCustomer {}));
    ctx.register_udtf(TpchOrders::name(), Arc::new(TpchOrders {}));
    ctx.register_udtf(TpchLineitem::name(), Arc::new(TpchLineitem {}));
    ctx.register_udtf(TpchPart::name(), Arc::new(TpchPart {}));
    ctx.register_udtf(TpchPartsupp::name(), Arc::new(TpchPartsupp {}));
    ctx.register_udtf(TpchSupplier::name(), Arc::new(TpchSupplier {}));
    ctx.register_udtf(TpchRegion::name(), Arc::new(TpchRegion {}));

    Ok(())
}
